"""TensorFlow interface for SciPy optimizers."""

# Code below is taken from
# https://github.com/tensorflow/tensorflow/blob/v1.15.2/tensorflow/contrib/opt/python/training/external_optimizer.py,
# because the ``tf.contrib`` module is not included in TensorFlow 2.

__all__ = ["ExternalOptimizerInterface", "ScipyOptimizerInterface"]

import numpy as np

from ...backend import tf


class ExternalOptimizerInterface:
    """Base class for interfaces with external optimization algorithms.

    Subclass this and implement `_minimize` in order to wrap a new optimization
    algorithm.

    `ExternalOptimizerInterface` should not be instantiated directly; instead use
    e.g. `ScipyOptimizerInterface`.

    The optimization variables are handled as one flat ("packed") vector: every
    variable is flattened and the results are concatenated in `var_list` order.
    """

    def __init__(
        self,
        loss,
        var_list=None,
        equalities=None,
        inequalities=None,
        var_to_bounds=None,
        **optimizer_kwargs
    ):
        """Initialize a new interface instance.

        Args:
          loss: A scalar `Tensor` to be minimized.
          var_list: Optional `list` of `Variable` objects to update to minimize
            `loss`.  Defaults to the list of variables collected in the graph
            under the key `GraphKeys.TRAINABLE_VARIABLES`.
          equalities: Optional `list` of equality constraint scalar `Tensor`s to be
            held equal to zero.
          inequalities: Optional `list` of inequality constraint scalar `Tensor`s
            to be held nonnegative.
          var_to_bounds: Optional `dict` where each key is an optimization
            `Variable` and each corresponding value is a length-2 tuple of
            `(low, high)` bounds. Although enforcing this kind of simple constraint
            could be accomplished with the `inequalities` arg, not all optimization
            algorithms support general inequality constraints, e.g. L-BFGS-B. Both
            `low` and `high` can either be numbers or anything convertible to a
            NumPy array that can be broadcast to the shape of `var` (using
            `np.broadcast_to`). To indicate that there is no bound, use `None` (or
            `+/- np.inf`). Variables missing from the dict are left unbounded.
            For example, if `var` is a 2x3 matrix, then any of the following
            corresponding `bounds` could be supplied:
            * `(0, np.inf)`: Each element of `var` held positive.
            * `(-np.inf, [1, 2])`: First column less than 1, second column less
              than 2.
            * `(-np.inf, [[1], [2], [3]])`: First row less than 1, second row less
              than 2, etc.
            * `(-np.inf, [[1, 2, 3], [4, 5, 6]])`: Entry `var[0, 0]` less than 1,
              `var[0, 1]` less than 2, etc.
          **optimizer_kwargs: Other subclass-specific keyword arguments.
        """
        self._loss = loss
        self._equalities = equalities or []
        self._inequalities = inequalities or []

        if var_list is None:
            self._vars = tf.trainable_variables()
        else:
            self._vars = list(var_list)

        packed_bounds = None
        if var_to_bounds is not None:
            # `np.inf` rather than the `np.infty` alias, which NumPy 2.0 removed.
            unbounded = (-np.inf, np.inf)
            packed_bounds = self._pack_bounds(
                [var.get_shape().as_list() for var in self._vars],
                [var_to_bounds.get(var, unbounded) for var in self._vars],
            )
        self._packed_bounds = packed_bounds

        # Placeholders through which `minimize` writes the optimized values back
        # into the variables. The values arrive as flat slices of the packed
        # vector, hence the reshape to each variable's own shape.
        self._update_placeholders = [tf.placeholder(var.dtype) for var in self._vars]
        self._var_updates = [
            var.assign(tf.reshape(placeholder, _get_shape_tuple(var)))
            for var, placeholder in zip(self._vars, self._update_placeholders)
        ]

        loss_grads = _compute_gradients(loss, self._vars)
        equalities_grads = [
            _compute_gradients(equality, self._vars) for equality in self._equalities
        ]
        inequalities_grads = [
            _compute_gradients(inequality, self._vars)
            for inequality in self._inequalities
        ]

        self.optimizer_kwargs = optimizer_kwargs

        self._packed_var = self._pack(self._vars)
        self._packed_loss_grad = self._pack(loss_grads)
        self._packed_equality_grads = [
            self._pack(equality_grads) for equality_grads in equalities_grads
        ]
        self._packed_inequality_grads = [
            self._pack(inequality_grads) for inequality_grads in inequalities_grads
        ]

        # `_packing_slices[i]` selects the entries of the packed vector that
        # belong to `self._vars[i]`.
        dims = [_prod(_get_shape_tuple(var)) for var in self._vars]
        accumulated_dims = list(_accumulate(dims))
        self._packing_slices = [
            slice(start, end)
            for start, end in zip(accumulated_dims[:-1], accumulated_dims[1:])
        ]

    def minimize(
        self,
        session=None,
        feed_dict=None,
        fetches=None,
        step_callback=None,
        loss_callback=None,
        **run_kwargs
    ):
        """Minimize a scalar `Tensor`.

        Variables subject to optimization are updated in-place at the end of
        optimization.

        Note that this method does *not* just return a minimization `Op`, unlike
        `Optimizer.minimize()`; instead it actually performs minimization by
        executing commands to control a `Session`.

        Args:
          session: A `Session` instance. Defaults to the default session.
          feed_dict: A feed dict to be passed to calls to `session.run`.
          fetches: A list (or other sequence) of `Tensor`s to fetch and supply to
            `loss_callback` as positional arguments.
          step_callback: A function to be called at each optimization step;
            arguments are the current values of all optimization variables
            flattened into a single vector.
          loss_callback: A function to be called every time the loss and gradients
            are computed, with evaluated fetches supplied as positional arguments.
          **run_kwargs: kwargs to pass to the final `session.run` call that
            assigns the optimized values to the variables.
        """
        session = session or tf.get_default_session()
        feed_dict = feed_dict or {}
        # Copy into a list: the evaluation closures concatenate this with a list
        # of tensors, which would fail for e.g. a tuple of fetches.
        fetches = list(fetches) if fetches else []

        loss_callback = loss_callback or (lambda *fetches: None)
        step_callback = step_callback or (lambda xk: None)

        # Construct loss function and associated gradient.
        loss_grad_func = self._make_eval_func(
            [self._loss, self._packed_loss_grad],
            session,
            feed_dict,
            fetches,
            loss_callback,
        )

        # Construct equality constraint functions and associated gradients.
        equality_funcs = self._make_eval_funcs(
            self._equalities, session, feed_dict, fetches
        )
        equality_grad_funcs = self._make_eval_funcs(
            self._packed_equality_grads, session, feed_dict, fetches
        )

        # Construct inequality constraint functions and associated gradients.
        inequality_funcs = self._make_eval_funcs(
            self._inequalities, session, feed_dict, fetches
        )
        inequality_grad_funcs = self._make_eval_funcs(
            self._packed_inequality_grads, session, feed_dict, fetches
        )

        # Get initial value from TF session.
        initial_packed_var_val = session.run(self._packed_var)

        # Perform minimization.
        packed_var_val = self._minimize(
            initial_val=initial_packed_var_val,
            loss_grad_func=loss_grad_func,
            equality_funcs=equality_funcs,
            equality_grad_funcs=equality_grad_funcs,
            inequality_funcs=inequality_funcs,
            inequality_grad_funcs=inequality_grad_funcs,
            packed_bounds=self._packed_bounds,
            step_callback=step_callback,
            optimizer_kwargs=self.optimizer_kwargs,
        )
        # Split the optimal packed vector back into one flat chunk per variable.
        var_vals = [
            packed_var_val[packing_slice] for packing_slice in self._packing_slices
        ]

        # Set optimization variables to their new values.
        session.run(
            self._var_updates,
            feed_dict=dict(zip(self._update_placeholders, var_vals)),
            **run_kwargs
        )

    def _minimize(
        self,
        initial_val,
        loss_grad_func,
        equality_funcs,
        equality_grad_funcs,
        inequality_funcs,
        inequality_grad_funcs,
        packed_bounds,
        step_callback,
        optimizer_kwargs,
    ):
        """Wrapper for a particular optimization algorithm implementation.

        It would be appropriate for a subclass implementation of this method to
        raise `NotImplementedError` if unsupported arguments are passed: e.g. if an
        algorithm does not support constraints but `len(equality_funcs) > 0`.

        Args:
          initial_val: A NumPy vector of initial values.
          loss_grad_func: A function accepting a NumPy packed variable vector and
            returning two outputs, a loss value and the gradient of that loss with
            respect to the packed variable vector.
          equality_funcs: A list of functions each of which specifies a scalar
            quantity that an optimizer should hold exactly zero.
          equality_grad_funcs: A list of gradients of equality_funcs.
          inequality_funcs: A list of functions each of which specifies a scalar
            quantity that an optimizer should hold >= 0.
          inequality_grad_funcs: A list of gradients of inequality_funcs.
          packed_bounds: A list of bounds for each index, or `None`.
          step_callback: A callback function to execute at each optimization step,
            supplied with the current value of the packed variable vector.
          optimizer_kwargs: Other key-value arguments available to the optimizer.

        Returns:
          The optimal variable vector as a NumPy vector.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "To use ExternalOptimizerInterface, subclass from it and implement "
            "the _minimize() method."
        )

    @staticmethod
    def _pack_bounds(shapes, bounds_per_var):
        """Flatten per-variable `(low, high)` bounds into per-element pairs.

        Args:
          shapes: One shape (list of ints) per variable.
          bounds_per_var: One `(low, high)` pair per variable, in the same order
            as `shapes`. Each of `low` and `high` is broadcast to the variable's
            shape with `np.broadcast_to`.

        Returns:
          A list of `(low, high)` tuples, one per element of the packed variable
          vector.
        """
        lows = []
        highs = []
        for shape, bounds in zip(shapes, bounds_per_var):
            lows.extend(np.broadcast_to(bounds[0], shape).flat)
            highs.extend(np.broadcast_to(bounds[1], shape).flat)
        return list(zip(lows, highs))

    @classmethod
    def _pack(cls, tensors):
        """Pack a list of `Tensor`s into a single, flattened, rank-1 `Tensor`.

        Returns `None` when `tensors` is empty.
        """
        if not tensors:
            return None
        elif len(tensors) == 1:
            # A single tensor only needs flattening, not concatenation.
            return tf.reshape(tensors[0], [-1])
        else:
            flattened = [tf.reshape(tensor, [-1]) for tensor in tensors]
            return tf.concat(flattened, 0)

    def _make_eval_func(self, tensors, session, feed_dict, fetches, callback=None):
        """Construct a function that evaluates a `Tensor` or list of `Tensor`s.

        Args:
          tensors: A `Tensor` or a list of `Tensor`s to evaluate.
          session: The `Session` used to run the evaluation.
          feed_dict: Extra feeds for every evaluation; these take precedence over
            the variable values fed from the packed vector.
          fetches: A list of additional `Tensor`s to evaluate alongside `tensors`.
          callback: Optional callable invoked after each evaluation with the
            evaluated `fetches` as positional arguments.

        Returns:
          A function taking a packed NumPy variable vector and returning the list
          of evaluated `tensors`.
        """
        if not isinstance(tensors, list):
            tensors = [tensors]
        num_tensors = len(tensors)

        def eval_func(x):
            """Function to evaluate a `Tensor`."""
            # Feed each variable its slice of the packed vector, restored to the
            # variable's shape.
            augmented_feed_dict = {
                var: x[packing_slice].reshape(_get_shape_tuple(var))
                for var, packing_slice in zip(self._vars, self._packing_slices)
            }
            augmented_feed_dict.update(feed_dict)
            augmented_fetches = tensors + fetches

            augmented_fetch_vals = session.run(
                augmented_fetches, feed_dict=augmented_feed_dict
            )

            # The values after the first `num_tensors` belong to `fetches`.
            if callable(callback):
                callback(*augmented_fetch_vals[num_tensors:])

            return augmented_fetch_vals[:num_tensors]

        return eval_func

    def _make_eval_funcs(self, tensors, session, feed_dict, fetches, callback=None):
        """Build one evaluation function per entry of `tensors`.

        Each entry is passed to `_make_eval_func` together with the remaining
        arguments.
        """
        return [
            self._make_eval_func(tensor, session, feed_dict, fetches, callback)
            for tensor in tensors
        ]


class ScipyOptimizerInterface(ExternalOptimizerInterface):
    """Wrapper allowing `scipy.optimize.minimize` to operate a `tf.compat.v1.Session`.

    The SciPy method is selected with the `method` keyword argument and defaults
    to L-BFGS-B. All other keyword arguments are forwarded to
    `scipy.optimize.minimize`, except `jac`, `callback`, `constraints` and
    `bounds`, which are set by this class and rejected with a `ValueError`.

    Example:
    ```python
    vector = tf.Variable([7., 7.], 'vector')
    # Make vector norm as small as possible.
    loss = tf.reduce_sum(tf.square(vector))
    optimizer = ScipyOptimizerInterface(loss, options={'maxiter': 100})
    with tf.compat.v1.Session() as session:
      optimizer.minimize(session)
    # The value of vector should now be [0., 0.].
    ```

    Example with simple bound constraints:
    ```python
    vector = tf.Variable([7., 7.], 'vector')
    # Make vector norm as small as possible.
    loss = tf.reduce_sum(tf.square(vector))
    optimizer = ScipyOptimizerInterface(
        loss, var_to_bounds={vector: ([1, 2], np.inf)})
    with tf.compat.v1.Session() as session:
      optimizer.minimize(session)
    # The value of vector should now be [1., 2.].
    ```

    Example with more complicated constraints:
    ```python
    vector = tf.Variable([7., 7.], 'vector')
    # Make vector norm as small as possible.
    loss = tf.reduce_sum(tf.square(vector))
    # Ensure the vector's y component is = 1.
    equalities = [vector[1] - 1.]
    # Ensure the vector's x component is >= 1.
    inequalities = [vector[0] - 1.]
    # Our default SciPy optimization algorithm, L-BFGS-B, does not support
    # general constraints. Thus we use SLSQP instead.
    optimizer = ScipyOptimizerInterface(
        loss, equalities=equalities, inequalities=inequalities, method='SLSQP')
    with tf.compat.v1.Session() as session:
      optimizer.minimize(session)
    # The value of vector should now be [1., 1.].
    ```
    """

    _DEFAULT_METHOD = "L-BFGS-B"

    def _minimize(
        self,
        initial_val,
        loss_grad_func,
        equality_funcs,
        equality_grad_funcs,
        inequality_funcs,
        inequality_grad_funcs,
        packed_bounds,
        step_callback,
        optimizer_kwargs,
    ):
        """Run `scipy.optimize.minimize` and return the optimal packed vector.

        Arguments are as described for `ExternalOptimizerInterface._minimize`.
        A summary of the optimization result is logged with `tf.logging.info`.

        Raises:
          ValueError: If `optimizer_kwargs` contains a keyword that this method
            sets itself (`jac`, `callback`, `constraints` or `bounds`).
        """

        def objective(point):
            # SciPy's L-BFGS-B Fortran implementation requires gradients as doubles.
            value, grad = loss_grad_func(point)
            return value, grad.astype("float64")

        # Work on a copy so the caller's kwargs are not modified by the pop below.
        user_kwargs = dict(optimizer_kwargs.items())
        method = user_kwargs.pop("method", self._DEFAULT_METHOD)

        eq_constraints = [
            {"type": "eq", "fun": fun, "jac": jac}
            for fun, jac in zip(equality_funcs, equality_grad_funcs)
        ]
        ineq_constraints = [
            {"type": "ineq", "fun": fun, "jac": jac}
            for fun, jac in zip(inequality_funcs, inequality_grad_funcs)
        ]

        minimize_kwargs = {
            "jac": True,
            "callback": step_callback,
            "method": method,
            "constraints": eq_constraints + ineq_constraints,
            "bounds": packed_bounds,
        }

        # Reject user kwargs that would override the values set above.
        for reserved in minimize_kwargs:
            if reserved not in user_kwargs:
                continue
            if reserved == "bounds":
                # Special handling for 'bounds' kwarg since ability to specify bounds
                # was added after this module was already publicly released.
                raise ValueError("Bounds must be set using the var_to_bounds argument")
            raise ValueError(
                "Optimizer keyword arg '{}' is set "
                "automatically and cannot be injected manually".format(reserved)
            )

        minimize_kwargs.update(user_kwargs)

        import scipy.optimize  # pylint: disable=g-import-not-at-top

        result = scipy.optimize.minimize(objective, initial_val, **minimize_kwargs)

        report_lines = [
            "Optimization terminated with:",
            "  Message: %s",
            "  Objective function value: %f",
        ]
        report_args = [result.message, result.fun]
        # Some optimization methods might not provide information such as nit and
        # nfev in the return. Logs only available information.
        optional_fields = (
            ("nit", "  Number of iterations: %d"),
            ("nfev", "  Number of functions evaluations: %d"),
        )
        for attr_name, line in optional_fields:
            if hasattr(result, attr_name):
                report_lines.append(line)
                report_args.append(getattr(result, attr_name))
        tf.logging.info("\n".join(report_lines), *report_args)

        return result["x"]


def _accumulate(list_):
    total = 0
    yield total
    for x in list_:
        total += x
        yield total


def _get_shape_tuple(tensor):
    return tuple(tensor.shape)


def _prod(array):
    prod = 1
    for value in array:
        prod *= value
    return prod


def _compute_gradients(tensor, var_list):
    """Return the gradients of `tensor` with respect to each entry of `var_list`.

    A `None` gradient from `tf.gradients` is replaced by a zero tensor shaped
    like the corresponding variable, so the result has one tensor per variable.
    """
    raw_grads = tf.gradients(tensor, var_list)
    dense_grads = []
    for variable, gradient in zip(var_list, raw_grads):
        # tf.gradients sometimes returns `None` when it should return 0.
        if gradient is None:
            gradient = tf.zeros_like(variable)
        dense_grads.append(gradient)
    return dense_grads
